#!/usr/bin/env python
# coding: utf-8
"""
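Computes a sequence of approximating 2-parameter (bifiltration) persistence modules
for the immuno dataset (LargeHypoxicRegion{idataset}.csv), one per sub-sample size.

Input parameters (positional command-line arguments):
- idataset : 1 or 2 : the dataset index
- res : int : the grid resolution (the alpha complex is computed on a res x res grid)
- num : int : the number of linearly spaced (and of log-spaced) sub-sample sizes; after deduplication, each size gives one module approximating the final, full-data one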
"""


# In[13]:

print("Loading libraries...")
import numpy as np
import matplotlib.pyplot as plt
import mma
from classif_helper import *
import gudhi as gd
from sklearn.neighbors import KernelDensity
from os import makedirs
from os.path import expanduser
from pandas import read_csv
from multiprocessing import cpu_count
from sys import argv
from copy import deepcopy
# Command-line arguments: <idataset> <res> <num> (see the module docstring).
if len(argv) < 4:
	print("Usage:", argv[0], "<idataset: 1|2> <res: int> <num: int>")
	raise SystemExit(1)
idataset = int(argv[1])
if idataset not in (1, 2):
	print("bad argument")
	# SystemExit(1) gives a non-zero exit status (so shell scripts can detect the
	# failure) and, unlike the site-provided exit(), also works under `python -S`.
	raise SystemExit(1)
res = int(argv[2])
num = int(argv[3])
print("Arguments :", *argv)
# In[14]:


# Root directory of the input CSV files.
DATASET_PATH = expanduser("~/Datasets/")


# # Dataset generation

# ## Retrieves data

# In[15]:


def get_immuno(i=idataset):
	"""Load the LargeHypoxicRegion{i} immuno dataset.

	Reads ``DATASET_PATH + f"LargeHypoxicRegion{i}.csv"`` and returns the ``x`` and
	``y`` coordinate columns together with the integer-encoded ``Celltype`` column.

	Parameters
	----------
	i : int
		Dataset index; defaults to the command-line ``idataset``.

	Returns
	-------
	X, Y : np.ndarray
		Point coordinates as plain NumPy arrays. They are deliberately not pandas
		Series: integer indexing of a Series (``X[indices]``, ``X[i]``) is
		label-based, which would leave X/Y in file order while a NumPy ``labels``
		array gets shuffled, silently misaligning points and cell types.
	labels : np.ndarray
		Cell types encoded as integers by ``LabelEncoder`` (presumably sklearn's,
		re-exported by classif_helper, which maps to 0..n_classes-1 — confirm).
	"""
	immu_dataset = read_csv(DATASET_PATH + f"LargeHypoxicRegion{i}.csv")
	X = np.asarray(immu_dataset['x'])
	Y = np.asarray(immu_dataset['y'])
	labels = LabelEncoder().fit_transform(immu_dataset['Celltype'])
	return X, Y, labels
print("Chosen dataset : " + DATASET_PATH+f"LargeHypoxicRegion{idataset}.csv")

# In[16]:

print("Printing dataset...", flush=True)
X, Y, labels = get_immuno()
# np.random.seed(42)
# Points are ordered by label in the file; shuffle them so the sample order carries
# no label information. np.asarray makes the fancy indexing positional for all three
# arrays (integer indexing of a pandas Series would select by label instead and
# desynchronise X/Y from labels).
indices = np.random.permutation(len(X))
X, Y, labels = (np.asarray(values)[indices] for values in (X, Y, labels))
makedirs("images", exist_ok=True)  # savefig does not create missing directories
plt.scatter(X, Y, c=labels, s=0.5)
plt.savefig(f"images/scatter_immuno{idataset}.svg")
plt.clf()

# In[17]:


# Split the (shuffled) points by cell-type label 0, 1 and 2. Stacking into one (n, 2)
# array and masking by label is positional, so each coordinate stays paired with its
# own label; an absent label yields an empty (0, 2) array.
_points_xy = np.column_stack([np.asarray(X), np.asarray(Y)])
X0, X1, X2 = (_points_xy[np.asarray(labels) == label] for label in (0, 1, 2))
print("Shapes :", X0.shape, X1.shape, X2.shape, flush = True)


# ## Makes a triangulation of this rectangle

# In[18]:

print("Computing Alpha Complex", flush=True)
resolution = [res,res]
# Bounds are taken over both coordinates together, so the grid is a square.
m,M = np.min(np.append(X,Y)), np.max(np.append(X,Y))
# Regular resolution[0] x resolution[1] grid on [m, M)^2 (M itself is not a node).
# Row j*resolution[0] + i is [m + i*(M-m)/resolution[0], m + j*(M-m)/resolution[1]],
# i.e. x varies fastest; built with NumPy instead of a Python loop over res**2 points.
_grid_x = m + np.arange(resolution[0])*(M-m)/resolution[0]
_grid_y = m + np.arange(resolution[1])*(M-m)/resolution[1]
_mesh_x, _mesh_y = np.meshgrid(_grid_x, _grid_y)
grid = np.column_stack([_mesh_x.ravel(), _mesh_y.ravel()])
alphacplx_grid = gd.AlphaComplex(points=grid)
simplextree_grid = alphacplx_grid.create_simplex_tree()
points_grid = np.array([alphacplx_grid.get_point(i) for i in range(simplextree_grid.num_vertices())]) # Alpha complex may reorder points
# NOTE(review): drawn on the current pyplot figure but neither saved nor cleared here.
plt.scatter(points_grid[:,0], points_grid[:,1], s=0.5)
print("Num simplicies : ", simplextree_grid.num_simplices(), "Num pts :", len(points_grid), flush=True)

# ## Bifiltration definition

# In[19]:


def get_kde(to_fit,to_estimate, **kwargs):
	"""Co-density of ``to_estimate`` under a kernel density estimate fitted on ``to_fit``.

	Only two keyword options are read; any other keys are ignored:
	``kde_kernel`` (default ``'gaussian'``) and ``kde_bandwidth`` (default ``0.5``).

	Returns
	-------
	np.ndarray
		The negated ``KernelDensity.score_samples`` values (negated log-density) at
		each point of ``to_estimate``.
	"""
	kernel_name = kwargs.get("kde_kernel", 'gaussian')
	width = kwargs.get("kde_bandwidth",0.5)
	estimator = KernelDensity(kernel=kernel_name, bandwidth=width)
	estimator.fit(to_fit)
	log_density = estimator.score_samples(to_estimate)
	# Negate so that dense regions get low values.
	return -np.array(log_density)
def get_bf(k:int=len(X), **kwargs):
	"""Build the 2-parameter codensity bifiltration from the first ``k`` samples.

	On every vertex of the grid alpha complex:
	- parameter 0 is the co-density (see ``get_kde``) of the label-0 points among the
	  first ``k`` samples;
	- parameter 1 is that of the label-1 points.
	Both are extended to the complex with ``fill_lowerstar``.

	Parameters
	----------
	k : int
		Number of leading samples of (X, Y, labels) to use; defaults to all of them.
	**kwargs
		KDE options (``kde_kernel``, ``kde_bandwidth``) forwarded to ``get_kde``.
		Keys not given here fall back on the script-level ``params``.

	Returns
	-------
	mma.SimplexTreeMulti
		Built from ``simplextree_grid`` with ``num_parameters=2``.

	Raises
	------
	ValueError
		If the first ``k`` samples contain no label-0 or no label-1 point (a KDE
		cannot be fitted on an empty set).
	"""
	# Explicit keyword arguments win; a bare get_bf(k) call uses the script-level params.
	kde_options = {**params, **kwargs}
	# Positional slicing/masking keeps coordinates paired with their labels even when
	# X/Y are pandas Series (integer indexing on a Series is label-based).
	xy = np.column_stack([np.asarray(X), np.asarray(Y)])[:k]
	first_labels = np.asarray(labels)[:k]
	x0 = xy[first_labels == 0]
	x1 = xy[first_labels == 1]
	if len(x0) == 0 or len(x1) == 0:
		raise ValueError(
			f"get_bf: the first {k} points must contain both label 0 and label 1 "
			f"(found {len(x0)} and {len(x1)} points)"
		)
	F1 = get_kde(x0, points_grid, **kde_options)
	F2 = get_kde(x1, points_grid, **kde_options)
	simplextree = mma.SimplexTreeMulti(simplextree_grid, num_parameters=2)
	simplextree.fill_lowerstar(F1, parameter=0)
	simplextree.fill_lowerstar(F2, parameter=1)
	# simplextree.collapse_edges(num=100, ignore_warning=True)
	# simplextree.expansion(2)
	return simplextree

# In[20]:

print("Defining parameters grid", flush=True)
# Options splatted (**params) into compute_mods below; get_kde reads the kde_* keys.
# NOTE(review): the meaning of the remaining keys is defined by compute_mods
# (classif_helper) — confirm there before changing them.
params = dict(
	n_jobs=cpu_count(),
	K=len(X),
	kmin=10,                # smallest sub-sample size (start of the size iterator)
	kmax=len(X),            # largest sub-sample size: every point
	nsamples=num,           # number of linearly / log-spaced sizes
	precision=10,
	degree=[0, 1],
	resolution=[200, 200],
	kde_bandwidth=0.5,      # KDE bandwidth read by get_kde
	box=[[0, 0], [20000, 20000]],
	kde_kernel="gaussian",  # KDE kernel read by get_kde
	normalize=1,
	bandwidth=100,
	ps=[0, 1, 2, np.inf],
	threshold=1000,
	flatten=True,
)
print(params)
box = params["box"]

# In[21]:

#print("Computing last image")
#last_img = compute_img(get_bf(**params),plot=True,size=5,**params)


# In[22]:
# Sub-sample sizes at which modules are computed: `num` linearly spaced and `num`
# log-spaced sizes between kmin and kmax, merged and deduplicated.
start = params["kmin"]
stop = params["kmax"]
num = params["nsamples"]
lin_it = np.linspace(start=start, stop=stop, num=num, dtype=int)
# Round rather than truncate: 10**log10(stop) is not always exactly `stop` in floating
# point, and a plain int cast would then produce stop-1 (and similar near-duplicates).
log_it = np.rint(np.logspace(start=np.log10(start), stop=np.log10(stop), num=num)).astype(int)
iterator = np.unique(np.concatenate([lin_it, log_it]))  # np.unique returns sorted values
out_dir = "modules/immuno/"
makedirs(out_dir, exist_ok=True)
file_suffix = f"{idataset}_kdebdw{params['kde_bandwidth']}_res{res}_num{num}"
with open(f"{out_dir}iterator{file_suffix}.np", "wb") as f:
	np.save(f, iterator)
print("Computing approximation modules ... ", flush=True)
compute_mods(iterator, get_bf, dump=True, save=f"{out_dir}module{file_suffix}_", **params)

print("Done !")




